# Natural Language Toolkit: Interface to TADM Classifier
#
# Copyright (C) 2008 NLTK Project 
# Author: Joseph Frazee <jfrazee@mail.utexas.edu>
# URL: <http://www.nltk.org/>
# For license information, see LICENSE.TXT

import sys
import subprocess

from nltk.internals import find_binary
try:
    import numpy
except ImportError:
    numpy = None

# Cached absolute path of the tadm binary; set lazily by config_tadm().
_tadm_bin = None
def config_tadm(bin=None):
    """
    Locate the external C{tadm} binary and cache its path in
    C{_tadm_bin}.  The search (delegated to
    L{nltk.internals.find_binary}) considers the explicit C{bin}
    argument, the C{TADM_DIR} environment variable, and the standard
    binary locations.

    @type bin: C{str}
    @param bin: The full path to the C{tadm} binary, if known;
        otherwise it is searched for.
    """
    global _tadm_bin
    _tadm_bin = find_binary(
        'tadm', bin,
        env_vars=['TADM_DIR'],
        binary_names=['tadm'],
        url='http://tadm.sf.net')
        
def write_tadm_file(train_toks, encoding, stream):
    """
    Generate an input file for C{tadm} based on the given corpus of
    classified tokens.

    @type train_toks: C{list} of C{tuples} of (C{dict}, C{str})
    @param train_toks: Training data, represented as a list of
        pairs, the first member of which is a feature dictionary,
        and the second of which is a classification label.

    @type encoding: L{TadmEventMaxentFeatureEncoding}
    @param encoding: A feature encoding, used to convert featuresets
        into feature vectors.

    @type stream: C{stream}
    @param stream: The stream to which the C{tadm} input file should be
        written.
    """
    # File format description:
    #
    # http://sf.net/forum/forum.php?thread_id=1391502&forum_id=473054
    # http://sf.net/forum/forum.php?thread_id=1675097&forum_id=473054
    all_labels = encoding.labels()
    num_labels = len(all_labels)
    for featureset, gold_label in train_toks:
        # Each event block begins with the number of candidate labels.
        stream.write('%d\n' % num_labels)
        for candidate in all_labels:
            # One line per candidate label: <is-correct> <num-features>
            # followed by alternating feature-id/value pairs.
            vector = encoding.encode(featureset, candidate)
            pairs = ' '.join('%d %d' % pair for pair in vector)
            stream.write('%d %d %s\n'
                         % (int(gold_label == candidate), len(vector), pairs))

def parse_tadm_weights(paramfile):
    """
    Given the stdout output generated by C{tadm} when training a
    model, return a C{numpy} array containing the corresponding weight
    vector.

    @type paramfile: iterable of C{str}
    @param paramfile: Lines of C{tadm} parameter output, one float
        per line.
    """
    # One weight per line; 'd' requests a double-precision array.
    return numpy.array([float(text.strip()) for text in paramfile], 'd')

def call_tadm(args):
    """
    Call the C{tadm} binary with the given arguments.

    @type args: C{list} of C{str}
    @param args: Command-line arguments to pass to C{tadm}.
    @raise TypeError: If C{args} is a single string rather than a
        list of strings.
    @raise OSError: If the C{tadm} command exits with a nonzero
        status.
    """
    # Guard against passing a bare string, which Popen would treat as a
    # sequence of single-character arguments.  (Was `basestring`, which
    # does not exist on Python 3.)
    if isinstance(args, str):
        raise TypeError('args should be a list of strings')
    if _tadm_bin is None:
        config_tadm()

    # Call tadm via a subprocess.  Capture stderr so it can be reported
    # on failure: the previous code passed only stdout=sys.stdout, so
    # communicate() returned (None, None) and the error report always
    # printed None instead of tadm's actual error text.
    cmd = [_tadm_bin] + args
    p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=subprocess.PIPE)
    (stdout, stderr) = p.communicate()

    # Check the return code.
    if p.returncode != 0:
        print()
        print(stderr)
        raise OSError('tadm command failed!')

def names_demo():
    """
    Train and evaluate a TADM-backed maxent classifier on the standard
    NLTK names-corpus demo.  Requires the external C{tadm} binary.
    """
    from nltk.classify.util import names_demo
    from nltk.classify.maxent import TadmMaxentClassifier
    classifier = names_demo(TadmMaxentClassifier.train)

def encoding_demo():
    """
    Demonstrate the C{tadm} input file format by writing a toy corpus
    to stdout, followed by a description of each encoded feature.
    """
    import sys
    from nltk.classify.maxent import TadmEventMaxentFeatureEncoding
    from nltk.classify.tadm import write_tadm_file
    tokens = [({'f0':1, 'f1':1, 'f3':1}, 'A'),
              ({'f0':1, 'f2':1, 'f4':1}, 'B'),
              ({'f0':2, 'f2':1, 'f3':1, 'f4':1}, 'A')]
    encoding = TadmEventMaxentFeatureEncoding.train(tokens)
    write_tadm_file(tokens, encoding, sys.stdout)
    # Python-2 `print` statements are syntax errors on Python 3;
    # use the print() function instead (same output).
    print()
    for i in range(encoding.length()):
        print('%s --> %d' % (encoding.describe(i), i))
    print()
        
if __name__ == '__main__':
    # Run both demos when executed as a script; requires NLTK and the
    # external tadm binary to be installed.
    encoding_demo()
    names_demo()
